import baostock as bs
import pandas as pd
import os 
import akshare as ak


#####################
class Get_codes():
    """Download index constituent lists and store them as CSV files.

    The HS300, ZZ500 and SZ50 lists are fetched through ``baostock``; the
    list for CSI index ``000852`` (saved as ``zz1000_stocks.csv``) is fetched
    through ``akshare``. Every file is written below ``self.root``.

    NOTE(review): the baostock queries presumably need an active
    ``bs.login()`` session opened by the caller -- confirm.
    """

    def __init__(self, root=None) -> None:
        """Remember the output directory and create it when missing.

        Args:
            root: Directory that receives the CSV files. Defaults to
                ``data/stock_data/code_date``.
        """
        if root is None:
            root = r"data/stock_data/code_date"
        # exist_ok avoids the check-then-create race of a separate exists().
        os.makedirs(root, exist_ok=True)
        self.root = root

    def _save_result_set(self, rs, filename):
        """Drain a baostock result set into ``filename`` under ``self.root``.

        Args:
            rs: Result object exposing ``error_code``, ``error_msg``,
                ``fields``, ``next()`` and ``get_row_data()``.
            filename: Name of the CSV file to write.

        Raises:
            RuntimeError: If the result set reports a non-"0" error code, so
                that a failed query does not overwrite an existing file with
                an empty one.
        """
        rows = []
        # Logical ``and`` short-circuits: next() is not called after an error.
        while rs.error_code == "0" and rs.next():
            # Fetch one record and collect it.
            rows.append(rs.get_row_data())
        if rs.error_code != "0":
            raise RuntimeError(
                f"baostock query for {filename} failed: "
                f"error_code={rs.error_code}, error_msg={rs.error_msg}"
            )
        result = pd.DataFrame(rows, columns=rs.fields)
        # Write the result set to a CSV file.
        result.to_csv(
            os.path.join(self.root, filename), encoding="utf-8", index=False
        )
        print("result", result)

    def get_hs300_codes(self):
        """Fetch the HS300 constituents and save them to ``hs300_stocks.csv``."""
        self._save_result_set(bs.query_hs300_stocks(), "hs300_stocks.csv")

    def get_zz500_codes(self):
        """Fetch the ZZ500 constituents and save them to ``zz500_stocks.csv``."""
        self._save_result_set(bs.query_zz500_stocks(), "zz500_stocks.csv")

    def get_sz50_codes(self):
        """Fetch the SZ50 constituents and save them to ``sz50_stocks.csv``."""
        self._save_result_set(bs.query_sz50_stocks(), "sz50_stocks.csv")

    @staticmethod
    def _exchange_prefix(exchange_name):
        """Map an exchange name to the ``sh.`` / ``sz.`` code prefix.

        Args:
            exchange_name: Text containing either "上海" or "深圳".

        Returns:
            ``"sh."`` or ``"sz."``.

        Raises:
            ValueError: If the name matches neither exchange.
        """
        text = str(exchange_name)
        if '上海' in text:
            return 'sh.'
        if '深圳' in text:
            return 'sz.'
        raise ValueError(f"unrecognised exchange name: {exchange_name!r}")

    @staticmethod
    def _build_prefixed_codes(constituents):
        """Build a one-column ``code`` frame such as ``sh.600000``.

        Args:
            constituents: DataFrame with the columns ``交易所`` (exchange
                name) and ``成分券代码`` (constituent code).

        Returns:
            DataFrame with a single ``code`` column, one row per input row.
        """
        prefixes = constituents['交易所'].apply(Get_codes._exchange_prefix)
        # Normalise to six-character strings so that a column parsed as
        # integers still concatenates and keeps its leading zeros.
        codes = constituents['成分券代码'].astype(str).str.zfill(6)
        return (prefixes + codes).to_frame(name='code')

    def get_zz1000_codes(self):
        """Fetch the constituents of index ``000852`` via akshare.

        The exchange-prefixed codes are saved to ``zz1000_stocks.csv``.

        Raises:
            ValueError: If a row names an exchange other than Shanghai or
                Shenzhen.
        """
        index_stock_cons_csindex_df = ak.index_stock_cons_csindex(symbol="000852")
        print(index_stock_cons_csindex_df)
        zz1000 = self._build_prefixed_codes(index_stock_cons_csindex_df)
        zz1000.to_csv(
            os.path.join(self.root, "zz1000_stocks.csv"),
            encoding="utf-8",
            index=False,
        )

